import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from utils import evaluate_f1_score
from data_loader import load_PPI
import argparse
import numpy as np
import os

class GNNFiLMLayer(nn.Module):
    """One GNN-FiLM message-passing layer over a multi-relational graph.

    For every edge type the layer owns a linear map ``W`` (applied to the
    source-node features) and a linear "hypernetwork" ``film`` whose output is
    split into a scale ``gamma`` and a shift ``beta``.  Messages are
    ``relu(gamma * W(h_src) + beta)``; they are summed per edge type and then
    summed across edge types, and the result goes through LayerNorm followed
    by dropout.

    Args:
        in_size: Width of the incoming node features.
        out_size: Width of the produced node features.
        etypes: Iterable of edge-type names; one ``W`` and one ``film`` module
            is created per name.
        dropout: Dropout probability applied to the layer output.
    """

    def __init__(self, in_size, out_size, etypes, dropout=0.1):
        super(GNNFiLMLayer, self).__init__()
        self.in_size = in_size
        self.out_size = out_size

        # Per-edge-type weights applied to the source-node features.
        self.W = nn.ModuleDict({
            name: nn.Linear(in_size, out_size, bias=False) for name in etypes
        })

        # Per-edge-type hypernetworks; each emits 2 * out_size values that are
        # split into the scale (gamma) and shift (beta) of the affine function.
        self.film = nn.ModuleDict({
            name: nn.Linear(in_size, 2 * out_size, bias=False) for name in etypes
        })

        # Normalisation applied to the aggregated node states.
        self.layernorm = nn.LayerNorm(out_size)

        # Dropout applied after the normalisation.
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat_dict):
        """Run one round of FiLM-modulated message passing.

        Args:
            g: The multi-relational graph, handled through the heterograph API
                (``canonical_etypes``, ``multi_update_all``).
            feat_dict: Mapping from node-type name to that type's feature
                tensor (2-D, last dimension ``in_size``).

        Returns:
            A new dict mapping every node type in ``g.ntypes`` to its updated
            feature tensor (last dimension ``out_size``).
        """
        # Work inside a local scope so the temporary message tensors and the
        # aggregated 'h' are not left behind on the caller's graph (which would
        # also keep the previous autograd graph alive between calls).
        with g.local_scope():
            funcs = {}  # edge type -> (message function, reduce function)
            for srctype, etype, dsttype in g.canonical_etypes:
                # Apply W_l to the source-side features.
                messages = self.W[etype](feat_dict[srctype])
                # The affine parameters are computed from the destination-type
                # features and split into scale / shift halves.
                film_weights = self.film[etype](feat_dict[dsttype])
                gamma = film_weights[:, :self.out_size]
                beta = film_weights[:, self.out_size:]
                # NOTE(review): gamma/beta are combined row-wise with the
                # source-side messages *before* they travel along edges, so
                # feat_dict[srctype] and feat_dict[dsttype] must be
                # broadcast-compatible (same node count in practice). Confirm
                # this matches the intended per-edge FiLM modulation.
                messages = F.relu_(gamma * messages + beta)
                # Stash the messages on the source nodes under the edge-type
                # name so copy_u can forward them along that edge type.
                g.nodes[srctype].data[etype] = messages
                funcs[etype] = (fn.copy_u(etype, 'm'), fn.sum('m', 'h'))
            # Sum within each edge type, then sum across edge types.
            g.multi_update_all(funcs, 'sum')
            # Read the results while still inside the scope; the tensors stay
            # valid after the scope is left.
            return {
                ntype: self.dropout(self.layernorm(g.nodes[ntype].data['h']))
                for ntype in g.ntypes
            }

class GNNFiLM(nn.Module):
    """Stack of GNN-FiLM layers followed by a linear prediction head.

    Args:
        etypes: Edge-type names forwarded to every ``GNNFiLMLayer``.
        in_size: Width of the input node features.
        hidden_size: Width of every layer's output.
        out_size: Width of the prediction (number of output units).
        num_layers: Requested depth; the first layer is always created, so the
            stack holds ``max(num_layers, 1)`` layers.
        dropout: Dropout probability forwarded to every layer.
    """

    def __init__(self, etypes, in_size, hidden_size, out_size, num_layers, dropout=0.1):
        super().__init__()
        # The first layer maps in_size -> hidden_size; the remaining
        # num_layers - 1 layers map hidden_size -> hidden_size.
        input_widths = [in_size] + [hidden_size] * (num_layers - 1)
        self.film_layers = nn.ModuleList(
            [GNNFiLMLayer(width, hidden_size, etypes, dropout) for width in input_widths]
        )
        self.predict = nn.Linear(hidden_size, out_size, bias=True)

    def forward(self, g, out_key):
        """Return sigmoid scores for the nodes of type ``out_key``.

        The initial features are read from the ``'feat'`` field of every node
        type in ``g`` and passed through each FiLM layer in turn.
        """
        feats = {}
        for ntype in g.ntypes:
            feats[ntype] = g.nodes[ntype].data['feat']
        for film_layer in self.film_layers:
            feats = film_layer(g, feats)
        # Project the final embedding of the requested node type and squash it
        # into (0, 1).
        return torch.sigmoid(self.predict(feats[out_key]))

def _run_epoch(model, batches, criterion, device, optimizer=None):
    """Run one pass over ``batches`` and return ``(mean_loss, mean_f1)``.

    When ``optimizer`` is given, a gradient step is taken after every batch;
    otherwise the pass only evaluates.  The caller is responsible for
    ``model.train()`` / ``model.eval()`` and for wrapping evaluation passes in
    ``torch.no_grad()``.
    """
    losses = []
    f1_scores = []
    for batch in batches:
        g = batch.graph.to(device)
        # '_N' is the node-type key whose embeddings are scored (presumably
        # DGL's default node-type name for homogeneous graphs -- confirm).
        logits = model(g, '_N')
        labels = batch.label
        loss = criterion(logits, labels)
        f1 = evaluate_f1_score(logits.detach().cpu().numpy(), labels.detach().cpu().numpy())
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        losses.append(loss.item())
        f1_scores.append(f1)
    return np.mean(losses), np.mean(f1_scores)


def main(args):
    """Train GNN-FiLM, keep the best-validation checkpoint and report test metrics.

    Args:
        args: Namespace providing ``gpu``, ``dataset``, ``batch_size``,
            ``hidden_size``, ``num_layers``, ``lr``, ``wd``, ``step_size``,
            ``gamma``, ``max_epoch``, ``early_stopping``, ``save_dir`` and
            ``name``; ``dropout`` is optional and defaults to 0.1.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset.
    """
    # Step 1: Prepare graph data and retrieve train/validation/test dataloader ============================= #
    if args.gpu >= 0 and torch.cuda.is_available():
        device = 'cuda:{}'.format(args.gpu)
    else:
        device = 'cpu'

    # Fail early with a clear message instead of an unbound-name error later.
    if args.dataset != 'PPI':
        raise ValueError(
            "Unsupported dataset {!r}; only 'PPI' is available.".format(args.dataset)
        )
    train_set, valid_set, test_set, etypes, in_size, out_size = load_PPI(args.batch_size, device)

    # Step 2: Create model and training components=========================================================== #
    # Forward the configured dropout rate; it used to be parsed but ignored.
    dropout = getattr(args, 'dropout', 0.1)
    model = GNNFiLM(etypes, in_size, args.hidden_size, out_size, args.num_layers, dropout).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.step_size, gamma=args.gamma)

    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    ckpt_path = os.path.join(args.save_dir, args.name)

    # Step 3: training epoches ============================================================================== #
    lastf1 = 0  # highest validation F1 seen so far (ties reset the patience counter)
    cnt = 0     # consecutive epochs whose validation F1 fell below lastf1
    # Start below any reachable score so the first epoch always writes a
    # checkpoint, even when its validation F1 is exactly 0.
    best_val_f1 = float('-inf')
    for epoch in range(args.max_epoch):
        model.train()
        train_loss, train_f1 = _run_epoch(model, train_set, criterion, device, optimizer)
        scheduler.step()

        model.eval()
        with torch.no_grad():
            val_loss, val_f1 = _run_epoch(model, valid_set, criterion, device)

        print('Epoch {:d} | Train Loss {:.4f} | Train F1 {:.4f} | Val Loss {:.4f} | Val F1 {:.4f} |'.format(epoch + 1, train_loss, train_f1, val_loss, val_f1))
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), ckpt_path)

        if val_f1 < lastf1:
            cnt += 1
            if cnt == args.early_stopping:
                print('Early stop.')
                break
        else:
            cnt = 0
            lastf1 = val_f1

    # Step 4: evaluate the best checkpoint on the test set ================================================== #
    if os.path.exists(ckpt_path):
        # map_location keeps the load working when the file was written from
        # a different device than the one selected for this run.
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        print('No checkpoint found at {}; evaluating the current weights.'.format(ckpt_path))
    model.eval()
    with torch.no_grad():
        test_loss, test_f1 = _run_epoch(model, test_set, criterion, device)

    print("Test F1: {:.4f} | Test loss: {:.4f}".format(test_f1, test_loss))



if __name__ == '__main__':
    # ArgumentDefaultsHelpFormatter appends each option's real default to its
    # help text, so the help can no longer drift away from the actual values.
    parser = argparse.ArgumentParser(
        description='GNN-FiLM',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--dataset", type=str, default="PPI", help="DGL dataset for this GNN-FiLM")
    parser.add_argument("--gpu", type=int, default=-1, help="GPU index; a negative value selects the CPU")
    parser.add_argument("--in_size", type=int, default=50, help="Input dimensionalities")
    parser.add_argument("--hidden_size", type=int, default=320, help="Hidden layer dimensionalities")
    parser.add_argument("--out_size", type=int, default=121, help="Output dimensionalities")
    parser.add_argument("--num_layers", type=int, default=4, help="Number of GNN layers")
    parser.add_argument("--batch_size", type=int, default=5, help="Batch size")
    parser.add_argument("--max_epoch", type=int, default=1500, help="The max number of epochs")
    parser.add_argument("--early_stopping", type=int, default=80,
                        help="Stop after this many consecutive epochs without a validation F1 improvement")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--wd", type=float, default=0.0009, help="Weight decay")
    # The hyphenated flag is kept for CLI compatibility; argparse exposes it
    # as args.step_size.
    parser.add_argument('--step-size', type=int, default=40, help='Period of learning rate decay')
    parser.add_argument('--gamma', type=float, default=0.8, help='Multiplicative factor of learning rate decay')
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout rate")
    parser.add_argument('--save_dir', type=str, default='./out', help='Path to save the model')
    parser.add_argument("--name", type=str, default='GNN-FiLM', help="Saved model name")

    args = parser.parse_args()
    print(args)
    # makedirs handles nested paths and an already-existing directory without
    # the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(args.save_dir, exist_ok=True)
    main(args)

